// This file is distributed under a BSD license. See LICENSE.txt for details.

#include "_types.hpp"
#include "_rygdxt.hpp"

/****************************************************************************/

namespace Werkk3TexLib
{

#if sLINK_RYGDXT

// Lookup tables, all filled in by sInitDXT().
// Expand5 / Expand6: 5-bit / 6-bit channel level -> 8-bit value (bit replication).
static sU8 Expand5[32], Expand6[64];
// OMatch5 / OMatch6: for every 8-bit value, the (max,min) pair of 5-/6-bit
// levels used to encode single-color blocks.
static sU8 OMatch5[256][2], OMatch6[256][2];
// QuantRBTab / QuantGTab: 8-bit value -> nearest 5-/6-bit level expanded back to
// 8 bits. Entry k stands for input k-8 (clamped to 0..255), giving 8 entries of
// headroom on either side.
static sU8 QuantRBTab[256+16], QuantGTab[256+16];

/****************************************************************************/

// Computes a*b/255 without a division: bias by 128, fold the high byte back
// in, then shift by 8. For 8-bit operands this rounds to nearest.
static sInt Mul8Bit(sInt a,sInt b)
{
  const sInt biased = a*b + 128;
  const sInt folded = biased + (biased >> 8);
  return folded >> 8;
}

union Pixel
{
  struct
  {
    sU8 b,g,r,a;
  };
  sU32 v;

  void From16Bit(sU16 v)
  {
    sInt rv = (v & 0xf800) >> 11;
    sInt gv = (v & 0x07e0) >>  5;
    sInt bv = (v & 0x001f) >>  0;

    a = 0;
    r = Expand5[rv];
    g = Expand6[gv];
    b = Expand5[bv];
  }

  sU16 As16Bit() const
  {
    return (Mul8Bit(r,31) << 11) + (Mul8Bit(g,63) << 5) + Mul8Bit(b,31);
  }

  void LerpRGB(const Pixel &p1,const Pixel &p2,sInt f)
  {
    r = p1.r + Mul8Bit(p2.r - p1.r,f);
    g = p1.g + Mul8Bit(p2.g - p1.g,f);
    b = p1.b + Mul8Bit(p2.b - p1.b,f);
  }
};

/****************************************************************************/

// Builds the single-color endpoint table.
// For every 8-bit target value, searches all (max,min) pairs of quantized
// levels (expand[] holds the 8-bit value of each of the 'size' levels). It keeps
// the pair whose 2/3*max + 1/3*min interpolant is closest to the target.
// Table[2*target] receives max and Table[2*target+1] receives min. On ties the
// first pair found (min outer, max inner) wins.
static void PrepareOptTable(sU8 *Table,const sU8 *expand,sInt size)
{
  for(sInt target=0;target<256;target++)
  {
    sU8 *entry = Table + target*2;
    sInt bestErr = 256;

    for(sInt lo=0;lo<size;lo++)
    {
      const sInt loVal = expand[lo];

      for(sInt hi=0;hi<size;hi++)
      {
        const sInt hiVal = expand[hi];
        const sInt decoded = hiVal + Mul8Bit(loVal-hiVal,0x55);
        const sInt err = sAbs(decoded - target);

        if(err >= bestErr)
          continue;

        entry[0] = hi;
        entry[1] = lo;
        bestErr = err;
      }
    }
  }
}

// Builds the 4-entry palette for 565 endpoints c0/c1.
// The entries are the two decoded endpoints plus the colors 0x55/255 and
// 0xaa/255 of the way from c0 to c1. Alpha of the two interpolated entries
// is not written.
static void EvalColors(Pixel *color,sU16 c0,sU16 c1)
{
  Pixel &first = color[0];
  Pixel &second = color[1];

  first.From16Bit(c0);
  second.From16Bit(c1);

  color[2].LerpRGB(first,second,0x55);
  color[3].LerpRGB(first,second,0xaa);
}

// Block dithering function. Simply dithers a block to 565 RGB.
// Uses Floyd-Steinberg-style error diffusion with 7/3/5/1 sixteenths weights,
// applied to each of the b/g/r channels in turn over the 4x4 block.
// dest receives the dithered b/g/r bytes. Its alpha bytes are not written.
static void DitherBlock(Pixel *dest,const Pixel *block)
{
  sInt err[8],*ep1 = err,*ep2 = err+4;

  // process channels separately
  for(sInt ch=0;ch<3;ch++)
  {
    // Byte-wise access: offset ch selects one channel byte of each 4-byte pixel.
    const sU8 *bp = ((const sU8 *) block) + ch;
    sU8 *dp = ((sU8 *) dest) + ch;
    // The quant tables carry 8 entries of padding on each side, so indices
    // -8..263 are valid. This presumably bounds the diffused error; verify.
    const sU8 *quant = (ch == 1) ? QuantGTab+8 : QuantRBTab+8;

    sSetMem(err,0,sizeof(err));

    for(sInt y=0;y<4;y++)
    {
      // pixel 0
      dp[ 0] = quant[bp[ 0] + ((3*ep2[1] + 5*ep2[0]) >> 4)];
      ep1[0] = bp[ 0] - dp[ 0];

      // pixel 1
      dp[ 4] = quant[bp[ 4] + ((7*ep1[0] + 3*ep2[2] + 5*ep2[1] + ep2[0]) >> 4)];
      ep1[1] = bp[ 4] - dp[ 4];

      // pixel 2
      dp[ 8] = quant[bp[ 8] + ((7*ep1[1] + 3*ep2[3] + 5*ep2[2] + ep2[1]) >> 4)];
      ep1[2] = bp[ 8] - dp[ 8];

      // pixel 3
      dp[12] = quant[bp[12] + ((7*ep1[2] + 5*ep2[3] + ep2[2]) >> 4)];
      ep1[3] = bp[12] - dp[12];

      // advance to next line (this row's errors become the "previous row")
      sSwap(ep1,ep2);
      bp += 16;
      dp += 16;
    }
  }
}

// Maps a value projected onto the c0-c1 axis to a palette index.
// Along that axis the palette stops lie, up to rounding, in the order
// stops[1] <= stops[3] <= stops[2] <= stops[0]. The three thresholds are the
// midpoints between neighbouring stops.
static sU32 MatchStep(sInt dot,sInt c0Point,sInt halfPoint,sInt c3Point)
{
  if(dot < halfPoint)
    return (dot < c0Point) ? 1 : 3;

  return (dot < c3Point) ? 2 : 0;
}

// The color matching function.
// Projects every block pixel and every palette entry onto the axis
// color[0]-color[1]. For each pixel it picks the palette index with the nearest
// projection, optionally with error diffusion when dither is set.
// Returns the 32-bit index mask: 2 bits per pixel, pixel 0 in the lowest bits.
static sU32 MatchColorsBlock(const Pixel *block,const Pixel *color,sBool dither)
{
  sU32 mask = 0;
  sInt dirr = color[0].r - color[1].r;
  sInt dirg = color[0].g - color[1].g;
  sInt dirb = color[0].b - color[1].b;

  sInt dots[16];
  for(sInt i=0;i<16;i++)
    dots[i] = block[i].r*dirr + block[i].g*dirg + block[i].b*dirb;

  sInt stops[4];
  for(sInt i=0;i<4;i++)
    stops[i] = color[i].r*dirr + color[i].g*dirg + color[i].b*dirb;

  sInt c0Point = (stops[1] + stops[3]) >> 1;
  sInt halfPoint = (stops[3] + stops[2]) >> 1;
  sInt c3Point = (stops[2] + stops[0]) >> 1;

  if(!dither)
  {
    // Walk backwards so pixel 0 ends up in the lowest two bits.
    for(sInt i=15;i>=0;i--)
      mask = (mask << 2) | MatchStep(dots[i],c0Point,halfPoint,c3Point);
  }
  else
  {
    // Floyd-Steinberg-style error diffusion along the axis. Dots and
    // thresholds are scaled by 16 so the 7/3/5/1 sixteenths weights stay
    // integral. Multiplication is used instead of << because these values
    // can be negative, and left-shifting a negative signed value is UB.
    sInt err[8],*ep1 = err,*ep2 = err+4;
    const sInt *dp = dots;

    c0Point *= 16;
    halfPoint *= 16;
    c3Point *= 16;
    for(sInt i=0;i<8;i++)
      err[i] = 0;

    for(sInt y=0;y<4;y++)
    {
      sInt dot;
      sU32 step,lmask;

      // pixel 0
      dot = dp[0]*16 + (3*ep2[1] + 5*ep2[0]);
      step = MatchStep(dot,c0Point,halfPoint,c3Point);
      ep1[0] = dp[0] - stops[step];
      lmask = step;

      // pixel 1
      dot = dp[1]*16 + (7*ep1[0] + 3*ep2[2] + 5*ep2[1] + ep2[0]);
      step = MatchStep(dot,c0Point,halfPoint,c3Point);
      ep1[1] = dp[1] - stops[step];
      lmask |= step<<2;

      // pixel 2
      dot = dp[2]*16 + (7*ep1[1] + 3*ep2[3] + 5*ep2[2] + ep2[1]);
      step = MatchStep(dot,c0Point,halfPoint,c3Point);
      ep1[2] = dp[2] - stops[step];
      lmask |= step<<4;

      // pixel 3
      dot = dp[3]*16 + (7*ep1[2] + 5*ep2[3] + ep2[2]);
      step = MatchStep(dot,c0Point,halfPoint,c3Point);
      ep1[3] = dp[3] - stops[step];
      lmask |= step<<6;

      // advance to next line. lmask is unsigned, so the shift by 24 for
      // the last row cannot overflow.
      sSwap(ep1,ep2);
      dp += 4;
      mask |= lmask << (y*8);
    }
  }

  return mask;
}

// The color optimization function. (Clever code, part 1)
static void OptimizeColorsBlock(const Pixel *block,sU16 &max16,sU16 &min16)
{
  static const sInt nIterPower = 4;

  // determine color distribution
  sInt mu[3],min[3],max[3];

  for(sInt ch=0;ch<3;ch++)
  {
    const sU8 *bp = ((const sU8 *) block) + ch;
    sInt muv,minv,maxv;

    muv = minv = maxv = bp[0];
    for(sInt i=4;i<64;i+=4)
    {
      muv += bp[i];
      minv = sMin<sInt>(minv,bp[i]);
      maxv = sMax<sInt>(maxv,bp[i]);
    }

    mu[ch] = (muv + 8) >> 4;
    min[ch] = minv;
    max[ch] = maxv;
  }

  // determine covariance matrix
  sInt cov[6];
  for(sInt i=0;i<6;i++)
    cov[i] = 0;

  for(sInt i=0;i<16;i++)
  {
    sInt r = block[i].r - mu[2];
    sInt g = block[i].g - mu[1];
    sInt b = block[i].b - mu[0];

    cov[0] += r*r;
    cov[1] += r*g;
    cov[2] += r*b;
    cov[3] += g*g;
    cov[4] += g*b;
    cov[5] += b*b;
  }

  // convert covariance matrix to float, find principal axis via power iter
  sF32 covf[6],vfr,vfg,vfb;
  for(sInt i=0;i<6;i++)
    covf[i] = cov[i] / 255.0f;

  vfr = max[2] - min[2];
  vfg = max[1] - min[1];
  vfb = max[0] - min[0];

  for(sInt iter=0;iter<nIterPower;iter++)
  {
    sF32 r = vfr*covf[0] + vfg*covf[1] + vfb*covf[2];
    sF32 g = vfr*covf[1] + vfg*covf[3] + vfb*covf[4];
    sF32 b = vfr*covf[2] + vfg*covf[4] + vfb*covf[5];

    vfr = r;
    vfg = g;
    vfb = b;
  }

  sF32 magn = sMax(sMax(sFAbs(vfr),sFAbs(vfg)),sFAbs(vfb));
  sInt v_r,v_g,v_b;

  if(magn < 4.0f) // too small, default to luminance
  {
    v_r = 148;
    v_g = 300;
    v_b = 58;
  }
  else
  {
    magn = 512.0f / magn;
    v_r = vfr * magn;
    v_g = vfg * magn;
    v_b = vfb * magn;
  }

  // Pick colors at extreme points
  sInt mind = 0x7fffffff,maxd = -0x7fffffff;
  Pixel minp,maxp;

  for(sInt i=0;i<16;i++)
  {
    sInt dot = block[i].r*v_r + block[i].g*v_g + block[i].b*v_b;

    if(dot < mind)
    {
      mind = dot;
      minp = block[i];
    }

    if(dot > maxd)
    {
      maxd = dot;
      maxp = block[i];
    }
  }

  // Reduce to 16 bit colors
  max16 = maxp.As16Bit();
  min16 = minp.As16Bit();
}

// The refinement function. (Clever code, part 2)
// Given the current index mask, re-solves for the two endpoints that best fit
// the block in the least-squares sense (normal equations + Cramer's rule).
// If the system is singular it returns sFALSE and leaves max16/min16 untouched.
// Otherwise it stores the new 565 endpoints and returns whether either changed.
static sBool RefineBlock(const Pixel *block,sU16 &max16,sU16 &min16,sU32 mask)
{
  // Weight (times 3) of the max endpoint for each index. The min endpoint's
  // weight is 3 minus this.
  static const sInt w1Tab[4] = { 3,0,2,1 };
  // Per index, w1*w1, w2*w2 and w1*w2 packed into bits 16+, 8..15 and 0..7,
  // so one addition per pixel accumulates all three sums (max 144 each).
  static const sInt prods[4] = { 0x090000,0x000900,0x040102,0x010402 };

  sInt akku = 0;
  sInt At1_r = 0, At1_g = 0, At1_b = 0;
  sInt sum_r = 0, sum_g = 0, sum_b = 0;

  for(sInt i=0;i<16;i++)
  {
    const sInt step = (mask >> (i*2)) & 3;
    const sInt w1 = w1Tab[step];
    const Pixel &p = block[i];

    akku  += prods[step];
    At1_r += w1*p.r;
    At1_g += w1*p.g;
    At1_b += w1*p.b;
    sum_r += p.r;
    sum_g += p.g;
    sum_b += p.b;
  }

  // right-hand side for the min endpoint: sum of (3 - w1) * pixel
  const sInt At2_r = 3*sum_r - At1_r;
  const sInt At2_g = 3*sum_g - At1_g;
  const sInt At2_b = 3*sum_b - At1_b;

  // unpack the normal-matrix entries
  const sInt xx = akku >> 16;
  const sInt yy = (akku >> 8) & 0xff;
  const sInt xy = akku & 0xff;

  // singular system (e.g. every pixel uses the same index): give up
  if(xx == 0 || yy == 0 || xx*yy == xy*xy)
    return sFALSE;

  const sF32 frb = 3.0f * 31.0f / 255.0f / (xx*yy - xy*xy);
  const sF32 fg = frb * 63.0f / 31.0f;

  const sU16 oldMin = min16;
  const sU16 oldMax = max16;

  // solve, rounding and clamping each component to its 5/6/5-bit range
  const sInt maxR = sRange<sInt>((At1_r*yy - At2_r*xy)*frb+0.5f,31,0);
  const sInt maxG = sRange<sInt>((At1_g*yy - At2_g*xy)*fg +0.5f,63,0);
  const sInt maxB = sRange<sInt>((At1_b*yy - At2_b*xy)*frb+0.5f,31,0);

  const sInt minR = sRange<sInt>((At2_r*xx - At1_r*xy)*frb+0.5f,31,0);
  const sInt minG = sRange<sInt>((At2_g*xx - At1_g*xy)*fg +0.5f,63,0);
  const sInt minB = sRange<sInt>((At2_b*xx - At1_b*xy)*frb+0.5f,31,0);

  max16 = (maxR << 11) | (maxG << 5) | maxB;
  min16 = (minR << 11) | (minG << 5) | minB;

  return oldMin != min16 || oldMax != max16;
}

// Color block compression.
// Encodes 16 pixels into an 8-byte color block. The block holds two 565
// endpoints (16 bits each) followed by a 32-bit index mask (2 bits per pixel),
// all stored little-endian. A non-zero quality enables dithering (DitherBlock
// plus the dithered matching path).
static void CompressColorBlock(sU8 *dest,const sU32 *src,sInt quality)
{
  const Pixel *block = (const Pixel *) src;
  Pixel dblock[16],color[4];

  // check if block is constant (compares all 32 bits, alpha included)
  sU32 min,max;
  min = max = block[0].v;

  for(sInt i=1;i<16;i++)
  {
    min = sMin(min,block[i].v);
    max = sMax(max,block[i].v);
  }

  // perform block compression
  sU16 min16,max16;
  sU32 mask;

  if(min != max) // no constant color
  {
    // first step: compute dithered version for PCA if desired
    if(quality)
      DitherBlock(dblock,block);

    // second step: pca+map along principal axis
    OptimizeColorsBlock(quality ? dblock : block,max16,min16);
    if(max16 != min16)
    {
      EvalColors(color,max16,min16);
      mask = MatchColorsBlock(block,color,quality != 0);
    }
    else
      mask = 0;

    // third step: refine
    if(RefineBlock(quality ? dblock : block,max16,min16,mask))
    {
      if(max16 != min16)
      {
        EvalColors(color,max16,min16);
        mask = MatchColorsBlock(block,color,quality != 0);
      }
      else
        mask = 0;
    }
  }
  else // constant color
  {
    sInt r = block[0].r;
    sInt g = block[0].g;
    sInt b = block[0].b;

    // Every pixel uses index 2 (the 2/3*max + 1/3*min color). This is the
    // interpolant the OMatch tables were optimized for (see PrepareOptTable).
    mask  = 0xaaaaaaaa;
    max16 = (OMatch5[r][0]<<11) | (OMatch6[g][0]<<5) | OMatch5[b][0];
    min16 = (OMatch5[r][1]<<11) | (OMatch6[g][1]<<5) | OMatch5[b][1];
  }

  // Order the endpoints so that max16 >= min16. DXT1 uses 4-color mode when
  // the first endpoint is greater. Swapping the endpoints requires swapping
  // indices 0<->1 and 2<->3, which is what the XOR does.
  if(max16 < min16)
  {
    sSwap(max16,min16);
    mask ^= 0x55555555;
  }

  // Store byte-wise in little-endian order. The previous sU16*/sU32* casts
  // assumed a little-endian host and an aligned dest, and violated strict aliasing.
  dest[0] = (sU8) (max16 & 0xff);
  dest[1] = (sU8) (max16 >> 8);
  dest[2] = (sU8) (min16 & 0xff);
  dest[3] = (sU8) (min16 >> 8);
  dest[4] = (sU8) (mask & 0xff);
  dest[5] = (sU8) ((mask >> 8) & 0xff);
  dest[6] = (sU8) ((mask >> 16) & 0xff);
  dest[7] = (sU8) (mask >> 24);
}

// Alpha block compression (this is easy for a change).
// Emits 8 bytes: max alpha, min alpha, then sixteen 3-bit indices packed
// LSB-first into the remaining six bytes.
static void CompressAlphaBlock(sU8 *dest,const sU32 *src,sInt quality)
{
  (void) quality; // not used: alpha encoding has no quality-dependent path

  const Pixel *block = (const Pixel *) src;

  // find min/max alpha
  sInt min,max;
  min = max = block[0].a;

  for(sInt i=1;i<16;i++)
  {
    min = sMin<sInt>(min,block[i].a);
    max = sMax<sInt>(max,block[i].a);
  }

  // encode them
  *dest++ = (sU8) max;
  *dest++ = (sU8) min;

  // determine bias and emit indices
  sInt dist = max-min;
  sInt bias = min*7 - (dist >> 1);
  sInt dist4 = dist*4;
  sInt dist2 = dist*2;
  sInt bits = 0,mask=0;

  for(sInt i=0;i<16;i++)
  {
    // a = 7*(alpha-min) + dist/2: the alpha's position on a 0..7 scale,
    // multiplied by dist, plus a rounding offset.
    sInt a = block[i].a*7 - bias;
    sInt ind = 0;

    // Determine the scale step (0..7) one bit at a time. Explicit comparisons
    // replace the former "(x) >> 31" sign extraction, which relied on an
    // implementation-defined arithmetic shift and a 32-bit sInt. The strict
    // '>' matches that test exactly.
    if(a > dist4) { ind += 4; a -= dist4; }
    if(a > dist2) { ind += 2; a -= dist2; }
    if(a > dist)  { ind += 1; }

    // Remap the linear scale (0 = min .. 7 = max) to index order.
    // Scale 7 -> index 0 (first stored byte, max), scale 0 -> index 1
    // (second stored byte, min), scales 6..1 -> indices 2..7.
    ind = -ind & 7;
    ind ^= (2 > ind);

    // write index
    mask |= ind << bits;
    if((bits += 3) >= 8)
    {
      *dest++ = (sU8) mask;
      mask >>= 8;
      bits -= 8;
    }
  }
}

/****************************************************************************/

void sInitDXT()
{
  for(sInt i=0;i<32;i++)
    Expand5[i] = (i<<3)|(i>>2);

  for(sInt i=0;i<64;i++)
    Expand6[i] = (i<<2)|(i>>4);

  for(sInt i=0;i<256+16;i++)
  {
    sInt v = sRange(i-8,255,0);
    QuantRBTab[i] = Expand5[Mul8Bit(v,31)];
    QuantGTab[i] = Expand6[Mul8Bit(v,63)];
  }

  PrepareOptTable(&OMatch5[0][0],Expand5,32);
  PrepareOptTable(&OMatch6[0][0],Expand6,64);
}

// Compresses one 4x4 block of 32-bit pixels.
// With alpha set, an 8-byte alpha block is written first and the 8-byte color
// block follows it. Otherwise only the 8-byte color block is written.
void sCompressDXTBlock(sU8 *dest,const sU32 *src,sBool alpha,sInt quality)
{
  sU8 *colorDest = dest;

  // if alpha is requested, compress it as well, ahead of the color part
  if(alpha)
  {
    CompressAlphaBlock(dest,src,quality);
    colorDest = dest + 8;
  }

  CompressColorBlock(colorDest,src,quality);
}

#endif

}

/****************************************************************************/
